import copy
import logging

import torch
import yaml
from gxl_ai_utils.config.gxl_config import GxlNode
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.k2.model import K2Model
from wenet.transformer.asr_model import ASRModel
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer

# Hard-coded decoding configuration for this run. Wrapped in GxlNode so the
# rest of the script can use attribute access (args.batch_size, args.hlg, ...).
_ARGS = {
    'modes': 'hlg_onebest',
    'test_data': '/home/work_nfs8/xlgeng/new_workspace/wenet_gxl_en_cn_2/examples/onlinesys/data/test/aishell1/shard.list',
    'data_type': 'shard',
    'checkpoint': '/home/work_nfs8/xlgeng/new_workspace/wenet_gxl_en_cn_2/examples/onlinesys/exp/final_work_stage6_gxl/step_14500.pt',
    'beam_size': 10,
    'batch_size': 16,
    'blank_penalty': 0.0,
    'result_dir': './output/aishell1',
    'ctc_weight': 0.5,
    # Lexicon (words.txt) and decoding graph (HLG.pt) for k2 HLG decoding.
    'word': '/home/work_nfs8/xlgeng/new_workspace/gxl_ai_utils/eggs/cats_and_dogs/prepare_data_for_en_cn/make_goutu/output_data/words.txt',
    'hlg': '/home/work_nfs8/xlgeng/new_workspace/gxl_ai_utils/eggs/cats_and_dogs/prepare_data_for_en_cn/make_goutu/output_data/HLG.pt',
    'lm_scale': 0.7,
    'decoder_scale': 0.1,
    'r_decoder_scale': 0.7,
    'config': 'exp/final_work_stage6_gxl/train.yaml',
}
args = GxlNode(_ARGS)
# Verbose DEBUG logging so every initialisation step below is visible.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')

# Load the training YAML and derive a deterministic evaluation config from
# its dataset section: filtering effectively disabled, no augmentation,
# no shuffling/sorting, fixed-size static batches.
with open(args.config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)
test_conf = copy.deepcopy(configs['dataset_conf'])

# Make the length/ratio filters pass everything.
test_conf['filter_conf'].update({
    'max_length': 102400,
    'min_length': 0,
    'token_max_length': 102400,
    'token_min_length': 0,
    'max_output_input_ratio': 102400,
    'min_output_input_ratio': 0,
})
# Turn off every training-time augmentation / reordering switch.
for _switch in ('speed_perturb', 'spec_aug', 'spec_sub', 'spec_trim',
                'shuffle', 'sort'):
    test_conf[_switch] = False
# Disable dither for reproducible features, whichever frontend is configured.
if 'fbank_conf' in test_conf:
    test_conf['fbank_conf']['dither'] = 0.0
elif 'mfcc_conf' in test_conf:
    test_conf['mfcc_conf']['dither'] = 0.0
test_conf['batch_conf'].update({
    'batch_type': "static",
    'batch_size': args.batch_size,
})

# Tokenizer built from the loaded configs; its symbol_table is reused below
# for blank-id lookup and HLG decoding.
tokenizer = init_tokenizer(configs)
# Evaluation dataset over the shard list; partition=False so the full test
# set is seen by this single process (no per-rank sharding).
test_dataset = Dataset(args.data_type,
                       args.test_data,
                       tokenizer,
                       test_conf,
                       partition=False)

# batch_size=None: the Dataset already emits ready-made batches according to
# test_conf['batch_conf'], so the DataLoader must not re-batch them.
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

# Build the model from configs and load the checkpoint referenced by `args`.
model, configs = init_model(args, configs)
_, blank_id = get_blank_id(configs, tokenizer.symbol_table)
logging.info("blank_id is {}".format(blank_id))
symbol_table = tokenizer.symbol_table

# Fix: the original hard-coded 'cuda' and left the model on CPU while moving
# inputs to the GPU, which crashes with a device mismatch. Select the device
# with a CPU fallback and move the model onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()  # disable dropout/batchnorm updates for deterministic decoding

# HLG decoding needs a K2Model. `assert` is stripped under `python -O`, so
# validate explicitly (once, outside the loop — the model does not change).
if not isinstance(model, K2Model):
    raise TypeError('HLG decoding requires a K2Model, got {}'.format(
        type(model).__name__))

with torch.no_grad():
    for batch_idx, batch in enumerate(test_data_loader):
        logging.info('耿雪龙： 哈哈哈哈哈哈，batch_idx = {}'.format(batch_idx))
        keys = batch["keys"]
        feats = batch["feats"].to(device)
        feats_lengths = batch["feats_lengths"].to(device)
        # NOTE(review): targets are never consumed by the decoding call below;
        # kept (and moved to device) only for parity with the original script.
        target = batch["target"].to(device)
        target_lengths = batch["target_lengths"].to(device)
        logging.info('耿雪龙： 哈哈哈哈哈哈，model.hlg_rescore, batch_idx = {}'.format(batch_idx))
        # Full-utterance (non-streaming) HLG decoding via the project-specific
        # entry point; returns the decoding result that is then logged.
        res = model.hello_gxl(
            speech=feats,
            speech_lengths=feats_lengths,
            decoding_chunk_size=-1,
            num_decoding_left_chunks=-1,
            simulate_streaming=False,
            hlg=args.hlg,
            word=args.word,
            symbol_table=symbol_table,
        )
        logging.info(res)



# Fix: the original wrapped this bare print in a broad `try/except Exception`
# that printed (and swallowed) any error. `print` of a short literal cannot
# meaningfully fail, and the handler would only hide real problems, so the
# guard is removed. (Leftover smoke-test output from development.)
print('hhh')
